import spacy
import jsonlines
import random
from tqdm import tqdm
import nlp

# spaCy English pipeline, loaded once at module import and reused by
# get_qa_entities for noun-chunk and part-of-speech based extraction.
spnlp = spacy.load('en_core_web_lg')

# Dataset splits iterated by each extract_* function.
splits = ["train", "validation", "test"]

def get_qa_entities(question,answerlist):
    stop_words = set(["a","an","the","how","who","what","which","where","when","is","was","that","there","and","or","any","if","their","your","you"])
    question_entities = []
    qdoc = spnlp(question.lower())
    docverbs=[x.text for x in qdoc if x.pos_=="VERB"]
    question_entities.extend([x.text for x in qdoc.noun_chunks])
    question_entities.extend(docverbs)
    all_ents = [] 
    for ents in question_entities:
        esplits = ents.split(" ")
        esplits = set(esplits)-stop_words
        all_ents.extend(esplits)
        
    question_entities=all_ents
    answerlist = [ x.lower() for x in answerlist]
    answer_entities=[]
    for aix,ans in enumerate(answerlist):
        aents=[]
        ans = ans.lower().replace(".","")
        ans_splits = ans.split(' ')
        if len(ans_splits)==1:
            aents.append(ans)
        else:
            adoc=spnlp(ans)
            ncs = [x.text for x in adoc.noun_chunks]
            ncverbs = [x.text for x in adoc if x.pos_=="VERB"]
            ncs.extend(ncverbs)
            for x in ncs:
                x = x.split(' ')
                x = set(x)-stop_words
                x_all = ' '.join(x)
                aents.append(x_all)
                aents.extend(x)
        answer_entities.append(aents)
    return question_entities,answer_entities

def process_ents(all_ents):
    pro_ents=[]
    for x in all_ents:
        words = x.split(" ")
        pro_ents.append(x)
        pro_ents.extend(words)
    return pro_ents

def extract_obqa():
    obqa = nlp.load_dataset("openbookqa")
    q_ents = []
    a_ents = []
    for split in splits:
        #gives only questions
        for question,ansmap in tqdm(zip(obqa[split]['question_stem'],obqa[split]['choices'])):
            answerlist = ansmap["text"]
            q_e,a_e = get_qa_entities(question,answerlist)
            q_ents.extend(q_e)
            for x in a_e:
                a_ents.extend(x)
#     print(q_ents[0:10],a_ents[0:10])
    q_ents=list(set(process_ents(q_ents)))
    a_ents=list(set(process_ents(a_ents)))
    with open("obqa_qents.csv","w") as ofd:
        for x in q_ents:
            ofd.write(f"{x}\n")
    with open("obqa_aents.csv","w") as ofd:
        for x in a_ents:
            ofd.write(f"{x}\n")
            
def extract_arc():
    obqa = nlp.load_dataset("ai2_arc")
    q_ents = []
    a_ents = []
    for split in splits:
        #gives only questions
        for question,ansmap in tqdm(zip(obqa[split]['question'],obqa[split]['choices'])):
            answerlist = ansmap["text"]
            q_e,a_e = get_qa_entities(question,answerlist)
            q_ents.extend(q_e)
            for x in a_e:
                a_ents.extend(x)
#     print(q_ents[0:10],a_ents[0:10])
    q_ents=list(set(process_ents(q_ents)))
    a_ents=list(set(process_ents(a_ents)))
    with open("arc_qents.csv","w") as ofd:
        for x in q_ents:
            ofd.write(f"{x}\n")
    with open("arc_aents.csv","w") as ofd:
        for x in a_ents:
            ofd.write(f"{x}\n")
            
def extract_comm():
    obqa = nlp.load_dataset("commonsense_qa")
    q_ents = []
    a_ents = []
    for split in splits:
        #gives only questions
        for question,ansmap in tqdm(zip(obqa[split]['question'],obqa[split]['choices'])):
            answerlist = ansmap["text"]
            q_e,a_e = get_qa_entities(question,answerlist)
            q_ents.extend(q_e)
            for x in a_e:
                a_ents.extend(x)
#     print(q_ents[0:10],a_ents[0:10])
    q_ents=list(set(process_ents(q_ents)))
    a_ents=list(set(process_ents(a_ents)))
    with open("comm_qents.csv","w") as ofd:
        for x in q_ents:
            ofd.write(f"{x}\n")
    with open("comm_aents.csv","w") as ofd:
        for x in a_ents:
            ofd.write(f"{x}\n")
    
    
# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    extract_comm()